Gemma/[Gemma_2]Gradio_Chatbot.ipynb (582 lines of code) (raw):

{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "tlvY59v6PjvC" }, "source": [ "##### Copyright 2025 Google LLC." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "jiFUWa49Pl1F" }, "outputs": [], "source": [ "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "XjJkne25Prps" }, "source": [ "# Building a Chatbot with Gemma and Gradio\n", "\n", "<table align=\"left\">\n", " <td>\n", " <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_2]Gradio_Chatbot.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n", " </td>\n", "</table>" ] }, { "cell_type": "markdown", "metadata": { "id": "QdM3rXG4U-mb" }, "source": [ "## Setup\n", "\n", "### Runtime Environment\n", "\n", " 1. Click **Open in Colab**.\n", " 2. In the menu, go to **Runtime** > **Change runtime type**.\n", " 3. Under **Hardware accelerator**, select **T4 GPU**.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "OCvGElGDDwnO" }, "source": [ "### Hugging Face Hub Access Token\n", "\n", "Before diving into the tutorial, let's set up Gemma:\n", "\n", "1. **Create a Hugging Face Account**: If you don't have one, you can sign up for a free account [here](https://huggingface.com/join).\n", "2. **Access the Gemma Model**: Visit the [Gemma model page](https://huggingface.com/collections/google/gemma-2-release-667d6600fd5220e7b967f315) and accept the usage terms.\n", "3. **Generate a Hugging Face Token**: Go to your Hugging Face [settings page](https://huggingface.com/settings/tokens) and generate a new access token (preferably with `write` permissions). You'll need this token later in the tutorial.\n", "\n", "**Once you've completed these steps, you're ready to move on to the next section, where you'll set up environment variables in your Colab environment.**" ] }, { "cell_type": "markdown", "metadata": { "id": "3pjmx0qYVI5_" }, "source": [ "### Configure Your Credentials\n", "\n", "To access private models and datasets, you need to log in to the Hugging Face (HF) ecosystem.\n", "\n", "If you're using Colab, you can securely store your Hugging Face token (`HF_TOKEN`) using the Colab Secrets Manager:\n", "1. Open your Google Colab notebook and click on the 🔑 Secrets tab in the left panel. <img src=\"https://storage.googleapis.com/generativeai-downloads/images/secrets.jpg\" alt=\"The Secrets tab is found on the left panel.\" width=50%>\n", "2. **Add Hugging Face Token**:\n", "- Create a new secret with the **name** `HF_TOKEN`.\n", "- Copy and paste your token key into the **Value** input box for `HF_TOKEN`.\n", "- **Toggle** the button on the left to allow notebook access to the secret" ] }, { "cell_type": "markdown", "metadata": { "id": "V7GXw4tjVS0H" }, "source": [ "This code retrieves your secrets and sets them as environment variables for use later in the tutorial." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "j6tFvOkyZz_o" }, "outputs": [], "source": [ "import os\n", "import sys\n", "\n", "if \"google.colab\" in sys.modules:\n", " from google.colab import userdata\n", " os.environ['HF_TOKEN'] = userdata.get(\"HF_TOKEN\")\n", "\n", "if \"HF_TOKEN\" not in os.environ:\n", " raise EnvironmentError(\n", " \"The Hugging Face token (HF_TOKEN) could not be found in the \"\n", " \"environment variables. This token is required to download the Gemma \"\n", " \"models from the Hugging Face Hub. For more information about \"\n", " \"HF User Access tokens, please refer to the HF documentation \"\n", " \"here: https://huggingface.co/docs/hub/en/security-tokens.\"\n", " )" ] }, { "cell_type": "markdown", "metadata": { "id": "_zS6gPp1VgPd" }, "source": [ "### Install dependencies\n", "\n", "Next, you'll install the required libraries. In this case, we only need gradio for the chat interface and transformers to load the Gemma model from the Hugging Face Hub.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bVaISQ8GZz_s" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m57.2/57.2 MB\u001b[0m \u001b[31m13.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m320.4/320.4 kB\u001b[0m \u001b[31m23.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m94.8/94.8 kB\u001b[0m \u001b[31m7.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m11.3/11.3 MB\u001b[0m \u001b[31m34.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m73.2/73.2 kB\u001b[0m \u001b[31m3.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.3/62.3 kB\u001b[0m \u001b[31m4.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m44.1/44.1 kB\u001b[0m \u001b[31m3.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.0/10.0 MB\u001b[0m \u001b[31m63.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.0/3.0 MB\u001b[0m \u001b[31m57.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h" ] } ], "source": [ "!pip install -q -U gradio==5.9.1\n", "!pip install -q -U transformers==4.46.1" ] }, { "cell_type": "markdown", "metadata": { "id": "l8_pGPDOtJ95" }, "source": [ "## Chat with Gemma using Gradio" ] }, { "cell_type": "markdown", "metadata": { "id": "KVyMsQc7WoIu" }, "source": [ "### Initializing Gemma 2 model\n", "\n", "Let's create a pipeline that will use the gemma-2-2b-it model to generate text. The transformers library provides an easy way to load the model and tokenizer into memory by simply specifying the model name and some basic parameters." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "TFhbJSgsWR00" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6b4fc98e36104ba38ed72a68c6f77562", "version_major": 2, "version_minor": 0 }, "text/plain": [ "tokenizer_config.json: 0%| | 0.00/47.0k [00:00<?, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5c50f9dcd070495c9e03d2b060dfac04", "version_major": 2, "version_minor": 0 }, "text/plain": [ "tokenizer.model: 0%| | 0.00/4.24M [00:00<?, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9db1ded5a8fc47bb9ac64860f21bc603", "version_major": 2, "version_minor": 0 }, "text/plain": [ "tokenizer.json: 0%| | 0.00/17.5M [00:00<?, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1bc4101d3dc94414b0c3a62715098123", "version_major": 2, "version_minor": 0 }, "text/plain": [ "special_tokens_map.json: 0%| | 0.00/636 [00:00<?, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "fe6d456cc92e4847addbac8092de9c60", "version_major": 2, "version_minor": 0 }, "text/plain": [ "config.json: 0%| | 0.00/838 [00:00<?, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a90c51330cf94f06b173738bd4e8c29b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "model.safetensors.index.json: 0%| | 0.00/24.2k [00:00<?, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "500aa7ccdd8f469b9e4f7546314e42b4", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading shards: 0%| | 0/2 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "27503afa871949ac936db998d37e68e8", "version_major": 2, "version_minor": 0 }, "text/plain": [ "model-00001-of-00002.safetensors: 0%| | 0.00/4.99G [00:00<?, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "febf4ae84e5a416fadc7f45168dff966", "version_major": 2, "version_minor": 0 }, "text/plain": [ "model-00002-of-00002.safetensors: 0%| | 0.00/241M [00:00<?, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "49efb4e381124a419650a86e1cb5d9dc", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5d02c68861ac48299b4b311466462158", "version_major": 2, "version_minor": 0 }, "text/plain": [ "generation_config.json: 0%| | 0.00/187 [00:00<?, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import torch\n", "import transformers\n", "\n", "# Model details\n", "model_name = \"google/gemma-2-2b-it\"\n", "device = \"cuda\"\n", "model_kwargs = {\n", " \"torch_dtype\": torch.float16,\n", "}\n", "\n", "# Load the Gemma tokenizer\n", "tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)\n", "\n", "# Create a pipeline\n", "pipeline = transformers.pipeline(\n", " \"text-generation\",\n", 
" model=model_name,\n", " device=device,\n", " model_kwargs=model_kwargs\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "WQunlnd6avP-" }, "source": [ "### Create a custom chat template\n", "\n", "Hugging Face supports chat templates that can define the structure and format for converting conversations into a single tokenizable string, which is the input format expected by the language model. Check the [chat templates documentation](https://huggingface.co/docs/transformers/main/en/chat_templating) to learn more about templates and how to create a custom one.\n", "\n", "Since Gemma doesn't support system instructions, you will provide system input as user input. This template has been adjusted to be compatible with Gradio's chat interface. To learn more about the format expected by Gemma, check out the [Gemma formatting documentation](https://ai.google.dev/gemma/docs/formatting)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sDTk1-bcZz_2" }, "outputs": [], "source": [ "tokenizer.chat_template = \\\n", " \"{{ bos_token }}\"\\\n", " \"{% if messages[0]['role'] == 'system' %}\"\\\n", " \"{{'<start_of_turn>user\\n' + messages[0]['content'] | trim + ' ' + messages[1]['content'] | trim + '<end_of_turn>\\n'}}\"\\\n", " \"{% set messages = messages[2:] %}\"\\\n", " \"{% endif %}\"\\\n", " \"{% for message in messages %}\"\\\n", " \"{% if message['role'] == 'user' %}\"\\\n", " \"{{'<start_of_turn>user\\n' + message['content'] | trim + '<end_of_turn>\\n'}}\"\\\n", " \"{% elif message['role'] == 'assistant' %}\"\\\n", " \"{{'<start_of_turn>model\\n' + message['content'] | trim + '<end_of_turn>\\n' }}\"\\\n", " \"{% endif %}\"\\\n", " \"{% endfor %}\"\\\n", " \"{% if add_generation_prompt %}\"\\\n", " \"{{ '<start_of_turn>model\\n' }}\"\\\n", " \"{% endif %}\"" ] }, { "cell_type": "markdown", "metadata": { "id": "cq5-Rda-tTeP" }, "source": [ "### Handle new messages\n", "\n", "Now, you need to define a function that will handle new messages (user inputs).\n", "\n", "To make the model context-aware, we need to provide:\n", "\n", "1. System message: The first message of the conversation that guides the behavior of the model during the chat.\n", "1. Chat history: Messages exchanged between the assistant and the user so far.\n", "1. New message: A new message sent by the user.\n", "\n", "All of this information is converted into a list of messages. Then, `apply_chat_template` is used to create the actual prompt (a long string with all the special tokens required by Gemma). The prompt is passed to the tokenizer and then to the model to generate the response." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "15WEWJYkNyHw" }, "outputs": [], "source": [ "from typing import List, Dict\n", "\n", "system_message = \"You're a helpful assistant.\"\n", "\n", "def chat_with_gemma(message: str, history: List[Dict[str, str]],\n", " max_new_tokens: int = 512) -> str:\n", " \"\"\"Chats with the Gemma 2 model and returns the response.\n", "\n", " This function takes a user message and chat history as input, formats them\n", " using the custom chat template, and generates a response using the Gemma 2\n", " pipeline.\n", "\n", " Args:\n", " message: The user's message.\n", " history: The chat history as a list of messages.\n", " max_new_tokens: The maximum number of new tokens to generate.\n", "\n", " Returns:\n", " response: Content generated by the model.\n", " \"\"\"\n", "\n", " # Combine system message, history and the new message into a list of messages.\n", " messages = [\n", " {\"role\": \"system\", \"content\": system_message},\n", " *history,\n", " {\"role\": \"user\", \"content\": message},\n", " ]\n", "\n", " # Apply the chat template to convert it into the prompt (string).\n", " prompt = tokenizer.apply_chat_template(\n", " messages,\n", " add_generation_prompt=True,\n", " tokenize=False\n", " )\n", "\n", " # Generate response using the pipeline defined above.\n", " outputs = pipeline(prompt, max_new_tokens=max_new_tokens)\n", "\n", " # A basic error handling mechanism. If something goes wrong, the\n", " # user will see \"Something went wrong...\" instead of a long error message.\n", " # It's usually a good place to handle quota limits, harmful content, etc.\n", " response = \"_Something went wrong. Please try again._\"\n", " try:\n", " response = outputs[0][\"generated_text\"][len(prompt):]\n", " except:\n", " pass\n", " return response" ] }, { "cell_type": "markdown", "metadata": { "id": "IyUmmZsVtdec" }, "source": [ "### Let's Run It!\n", "\n", "Now, we will use Gradio's `ChatInterface` to create an interactive chat interface that will allow you to chat with our Gemma 2 model! In this case, it will create a window inside Google Colab, but if you run it in a standalone file, it will start an HTTP server, and you will be able to access the chat from your browser." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "nOOQMpW6N90V" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running Gradio in a Colab notebook requires sharing enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).\n", "\n", "Colab notebook detected. To show errors in colab notebook, set debug=True in launch()\n", "* Running on public URL: https://29b77c0650271c6a24.gradio.live\n", "\n", "This share link expires in 72 hours. 
For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n" ] }, { "data": { "text/html": [ "<div><iframe src=\"https://29b77c0650271c6a24.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import gradio as gr\n", "\n", "gr.ChatInterface(\n", " fn=chat_with_gemma,\n", " type=\"messages\"\n", ").launch()" ] }, { "cell_type": "markdown", "metadata": { "id": "1d9i1kdB7lt8" }, "source": [ "## What's Next?\n", "\n", "That's it! If you're wondering how to make your chatbot even better, check out the following resources:\n", "\n", "- **Explore the Gemma family models:** Visit [Gemma Open Models](https://ai.google.dev/gemma) to learn about the latest Gemma releases, new capabilities, versions, and more.\n", "- **Gradio Customization:** Explore the [Gradio documentation](https://www.gradio.app/docs) to learn about customizing your chat interface and adding new options and features.\n", "- **Share Your Gradio Dashboard:** Check out the [Sharing your Gradio app](https://www.gradio.app/guides/sharing-your-app) page to learn how to safely share your Gradio dashboard with others!" ] } ], "metadata": { "accelerator": "GPU", "colab": { "name": "[Gemma_2]Gradio_Chatbot.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }